import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from Network.network import Network
from Network.network_utils import reduce_function
from Network.General.Flat.mlp import MLPNetwork
from Network.General.Conv.conv import ConvNetwork
from Network.General.Factor.Pair.pair import PairNetwork
from Network.General.Factor.factored import return_values


class MaskedMLPNetwork(Network):
    def __init__(self, args):
        '''
        Applies a single MLP to the flattened concatenation of
        [key (optional), masked query, mask (optional)].

        Reads from args:
            factor: factor parameters; uses first_obj_dim, object_dim and num_objects
                to size the MLP input
            output_dim: MLP output size (also stored as embed_dim)
            aggregate_final: stored on the instance
            factor_net.append_keys: if truthy, the flattened key is prepended
            factor_net.append_mask: if truthy, one mask value per object is appended
        '''
        super().__init__(args)
        # assumes the input is flattened list of input space sized values
        # needs an object dim
        self.fp = args.factor
        self.embed_dim = args.output_dim
        self.aggregate_final = args.aggregate_final
        self.append_keys = args.factor_net.append_keys
        self.append_mask = args.factor_net.append_mask

        mlp_args = copy.deepcopy(args)
        # The input width must match exactly what forward() concatenates, so it is
        # derived from the same flags forward() checks (self.append_keys / self.append_mask).
        mlp_args.num_inputs = (int(self.append_keys) * self.fp.first_obj_dim
                               + self.fp.object_dim * self.fp.num_objects
                               + int(self.append_mask) * self.fp.num_objects)
        mlp_args.num_outputs = args.output_dim
        self.mlp = MLPNetwork(mlp_args)
        self.train()
        self.reset_network_parameters()

    def forward(self, key, query, mask, ret_settings):
        '''
        Run the MLP on the (optionally masked) query, with key/mask optionally attached.

        Args:
            key: key tensor with batch as its leading dim; flattened per row and
                prepended when append_keys is set.
            query: query tensor with batch as its leading dim; flattened per row.
            mask: optional tensor multiplied into query (must broadcast against it).
                When append_mask is set it is also flattened and appended; if None,
                a mask of ones with num_objects entries per row is used instead.
            ret_settings: forwarded unchanged to return_values.

        Returns:
            return_values(ret_settings, mlp_output, None, None)
        '''
        batch_size = query.shape[0]
        # Apply the mask to the tensor that is actually fed to the MLP.
        x = query * mask if mask is not None else query
        # The MLP expects a flat (batch, features) input; no-op if already flat.
        x = x.reshape(batch_size, -1)
        if self.append_keys:
            x = torch.cat([key.reshape(key.shape[0], -1), x], dim=1)
        if self.append_mask:
            if mask is None:
                # One entry per object, matching the width reserved in __init__.
                mask = torch.ones(batch_size, self.fp.num_objects,
                                  device=x.device, dtype=x.dtype)
            x = torch.cat([x, mask.reshape(mask.shape[0], -1)], dim=-1)
        x = self.mlp(x)
        return return_values(ret_settings, x, None, None)